#################################################################################
# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
# All Rights Reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################

#################
# Utilities for simpler ONNX graph export. Should go away with a newer version of PyTorch (>1.6) and OpSet >= 11.
#################

import os
import shutil
import numpy as np
import cv2
import torch
from .. import layers as xtensor_layers
from .image_utils import *
from .print_utils import *
import torch.nn.functional as F
from .count_flops import *

def interpolate_simplified(input_interpolate, size=None, scale_factor=None, mode='nearest', align_corners=None,
                recompute_scale_factor=False, params_simplified=False, params_opset=9):
    """Dispatch to the torch-version-appropriate interpolate wrapper.

    PyTorch <= 1.2 does not accept the ``recompute_scale_factor`` keyword in
    ``F.interpolate``, so a separate implementation is used for old versions.

    Args mirror ``torch.nn.functional.interpolate``, plus:
        params_simplified: if True, drive interpolate via scale_factor only,
            which produces a simpler ONNX graph (see the wrapped functions).
        params_opset: target ONNX opset; opset 9 forces align_corners=False
            in the wrapped functions.

    Returns:
        The interpolated tensor from the selected wrapper.
    """
    # Parse torch.__version__ robustly: the patch component may carry a
    # local-build suffix such as "0+cu117" or "0a0" which int() rejects.
    parts = torch.__version__.split('.')
    patch_digits = ''
    for ch in (parts[2] if len(parts) > 2 else '0'):
        if not ch.isdigit():
            break
        patch_digits += ch
    torch_version = 100 * int(parts[0]) + 10 * int(parts[1]) + int(patch_digits or '0')
    print_once("torch_version: ", torch_version)
    if torch_version > 120:
        return interpolate_simplified_ptver_gt120(input_interpolate, size=size, scale_factor=scale_factor, mode=mode,
                                   align_corners=align_corners, recompute_scale_factor=recompute_scale_factor,
                                   params_simplified=params_simplified, params_opset=params_opset)
    else:
        return interpolate_simplified_ptver_lte120(input_interpolate, size=size, scale_factor=scale_factor, mode=mode,
                                      align_corners=align_corners, recompute_scale_factor=recompute_scale_factor,
                                      params_simplified=params_simplified, params_opset=params_opset)

#for torch version > 1.2
def interpolate_simplified_ptver_gt120(input_interpolate, size=None, scale_factor=None, mode='nearest', align_corners=None,
                recompute_scale_factor=False, params_simplified=False, params_opset=9):
    """Interpolate wrapper for PyTorch > 1.2 (supports recompute_scale_factor).

    When ``params_simplified`` is True, a fixed ``size`` is converted to an
    explicit ``scale_factor`` so the exported ONNX graph carries constants
    instead of a runtime shape subgraph; for opset 9, ``align_corners`` is
    forced to False because the opset-9 Upsample op cannot represent it.

    Returns:
        The tensor produced by ``F.interpolate``.
    """
    if params_simplified:
        align_corners = False if params_opset == 9 else align_corners

        # Fix: align_corners is only legal for the interpolating modes;
        # passing it with 'nearest'/'area' makes F.interpolate raise.
        # Mirrors the guard already present in
        # interpolate_simplified_ptver_lte120.
        if not mode in ['linear', 'bilinear', 'bicubic', 'trilinear']:
            align_corners = None

        if size is not None:
            # Assumes a 4-D NCHW input: dims 2 and 3 are spatial.
            scale_factor = [float(size[0]) / float(input_interpolate.shape[2]), float(size[1]) / float(input_interpolate.shape[3])]
        return F.interpolate(input_interpolate, scale_factor=scale_factor, mode=mode, align_corners=align_corners,
                         recompute_scale_factor=recompute_scale_factor)
    else:
        return F.interpolate(input_interpolate, scale_factor=scale_factor, size=size, mode=mode,
                            align_corners=align_corners, recompute_scale_factor=recompute_scale_factor)

# for torch version <= 1.2
# Torch1.1 does not have recompute_scale_factor as parameter. so need separate function.
def interpolate_simplified_ptver_lte120(input_interpolate, size=None, scale_factor=None, mode='nearest', align_corners=None,
                recompute_scale_factor=False, params_simplified=False, params_opset=9):
    """Interpolate wrapper for PyTorch <= 1.2.

    Torch 1.1/1.2 has no ``recompute_scale_factor`` parameter on
    ``F.interpolate``, so that argument is accepted here only for signature
    compatibility and never forwarded.

    With ``params_simplified`` set, a fixed ``size`` is rewritten as an
    explicit ``scale_factor`` (simpler ONNX graph) and ``align_corners`` is
    forced to False for opset 9, then cleared for modes that do not accept it.

    Returns:
        The tensor produced by ``F.interpolate``.
    """
    if not params_simplified:
        return F.interpolate(input_interpolate, scale_factor=scale_factor, size=size, mode=mode,
                             align_corners=align_corners)

    # Opset 9's Upsample op cannot encode align_corners.
    if params_opset == 9:
        align_corners = False

    # align_corners is only accepted by the interpolating modes.
    interpolating_modes = ('linear', 'bilinear', 'bicubic', 'trilinear')
    if mode not in interpolating_modes:
        align_corners = None

    if size is not None:
        # Assumes a 4-D NCHW input: dims 2 and 3 are spatial.
        in_height = float(input_interpolate.shape[2])
        in_width = float(input_interpolate.shape[3])
        scale_factor = [float(size[0]) / in_height, float(size[1]) / in_width]

    return F.interpolate(input_interpolate, scale_factor=scale_factor, mode=mode, align_corners=align_corners)


def comp_latency(inputNetFile=None, tidl_params=None):
    """Run the TIDL performance simulator on an exported ONNX model and print cycle/time estimates.

    Args:
        inputNetFile: path to the .onnx network file to simulate.
        tidl_params: optional dict; recognized keys are 'outDataNamesList',
            'tidl_dir', 'save_tidl_files' and 'BATCH_SIZE'. Missing keys fall
            back to None / None / None / 64 respectively.

    Side effects:
        Removes any previous perf-sim work directory, creates a new one when
        saving is enabled, and prints the simulator's cycle and time results.
    """
    import os
    from pytidl.model.pytidl_model import PyTIDLModel

    inWidth = 224
    inHeight = 224

    # Fix: the original only bound these names inside "if tidl_params is not
    # None", raising NameError below when tidl_params was omitted.
    tidl_params = tidl_params if tidl_params is not None else {}
    outDataNamesList = tidl_params.get('outDataNamesList')
    tidl_dir = tidl_params.get('tidl_dir')
    save_files = tidl_params.get('save_tidl_files')
    BATCH_SIZE = tidl_params.get('BATCH_SIZE', 64)

    perfSimConfig = dict(MSMCSIZE_KB=796800, BATCH_SIZE=BATCH_SIZE, MATPANELSIZE=64)
    filename, file_extension = os.path.splitext(inputNetFile)
    workDir = filename + "-perf-sim-{}x{}-bs{}".format(inWidth, inHeight, BATCH_SIZE) if save_files else None

    # Start from a clean work directory so stale simulator output is not mixed in.
    if workDir is not None and os.path.exists(workDir):
        print("removing old dir structure : {}".format(workDir))
        shutil.rmtree(workDir, ignore_errors=True)

    #default is 577 = 0x241, subtract 1 for removing Nci constrain
    compileConstraintsFlag = 577-1
    pytidl_model = PyTIDLModel(tidlDir=tidl_dir,
                            inputNetFile=inputNetFile,
                            #inputParamsFile=inputParamsFileTemp,
                            inMean=(123.675, 116.28, 103.53),
                            inScale=(0.017125, 0.017507, 0.017429),
                            inData='./examples/inputs',
                            workDir=workDir,
                            resizeWidth=256,
                            resizeHeight=256,
                            inWidth=inWidth,
                            inHeight=inHeight,
                            numParamBits=8,
                            debugTraceLevel=1,
                            outDataNamesList = outDataNamesList,
                            #inDataNamesList=inDataNamesList,
                            perfSimConfig=perfSimConfig,
                            numFrames=1,
                            enableHighResOptimization=True,
                            compileConstraintsFlag=compileConstraintsFlag)

    perfsim_dict = pytidl_model.perfsim_data()
    print(perfsim_dict['cycles'], perfsim_dict['cycles_unit'], perfsim_dict['time'], perfsim_dict['time_unit'] )

def export_models(net=None, write_onnx=False, write_torch_script=True, compute_flops=False, op_path=None,
                model_name='model_name', ip_tensor_size=(2, 3, 1024, 1024), opset_version=9, device='cpu',
                log_func=print, model_summary=False, do_constant_folding=True, en_comp_latency=False,
                print_model=True, tidl_params=None, compute_fw_bw_flops=False, dummy_input=None):
    """Export a model as TorchScript and/or ONNX, optionally reporting FLOPs and TIDL latency.

    Args:
        net: torch.nn.Module to export (must have at least one parameter).
        write_onnx: export to "<op_path>/<model_name>.onnx" when True.
        write_torch_script: trace and save "<op_path>/<model_name>_ts.pth"
            when True (requires torch >= 1.1).
        compute_flops / compute_fw_bw_flops: report forward (and fw+bw) flops
            via the count_flops helpers.
        op_path: output directory for the exported files.
        ip_tensor_size: (batch, channels, height, width) used to build the
            dummy input when dummy_input is not supplied. (Fix: was a mutable
            list default; a tuple avoids the shared-mutable-default pitfall.)
        opset_version: ONNX opset passed to torch.onnx.export.
        device: ignored on entry; overridden by where net's parameters live.
        log_func: logging callable (default print).
        model_summary: print a torchsummary report before ONNX export.
        do_constant_folding: forwarded to torch.onnx.export.
        en_comp_latency: run the TIDL perf simulator on the exported ONNX.
        tidl_params: forwarded to comp_latency.
        dummy_input: optional pre-built example input tensor.
    """
    # The requested device is overridden by where the model actually lives.
    is_cuda = next(net.parameters()).is_cuda
    device = 'cuda' if is_cuda else 'cpu'
    if print_model:
        log_func(net)
    sample_batch_size, channel, height, width = ip_tensor_size
    if dummy_input is None:
        dummy_input = torch.randn(sample_batch_size, channel, height, width, device=device)
    if compute_flops:
        flops = forward_count_flops(net, dummy_input)
        log_func("GFlops : {:.2f} GMAC : {:.2f} ".format(flops/1E9, (flops/(2*1E9))))

    if compute_fw_bw_flops:
        num_fw_bw_flops, op_vol = fw_bw_count_flops(net, dummy_input)
        print("Num fw + bw ops  in G: {}, op_vol in MB : {}".format(num_fw_bw_flops/1E9, op_vol/1E6))

    # Parse torch.__version__ robustly: the patch component may carry a
    # local-build suffix such as "0+cu117" or "0a0" which int() rejects.
    version = torch.__version__.split('.')
    patch_digits = ''
    for ch in (version[2] if len(version) > 2 else '0'):
        if not ch.isdigit():
            break
        patch_digits += ch
    torch_version = 100*int(version[0]) + 10*int(version[1]) + int(patch_digits or '0')
    print_once("torch_version: ", torch_version)
    if write_torch_script and torch_version >= 110:
        # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
        traced_script_module = torch.jit.trace(net, dummy_input)
        pt_file_name = os.path.join(op_path, "{}_ts.pth".format(model_name))
        log_func("saving {}".format(pt_file_name))
        torch.jit.save(traced_script_module, pt_file_name)

    if write_onnx:
        onnx_file_name = os.path.join(op_path, "{}.onnx".format(model_name))
        log_func(onnx_file_name)
        if model_summary:
            from torchsummary import summary
            log_func("Model Summary Start")
            summary(net, dummy_input.shape[1:], device=device)
            log_func("Model Summary End")
        # NOTE(review): enable_onnx_checker was deprecated and later removed
        # from torch.onnx.export in newer PyTorch; this branch targets the
        # 1.5-1.6 era API the file header mentions — confirm before upgrading.
        if torch_version >= 150:
            torch.onnx.export(net, dummy_input, onnx_file_name, do_constant_folding=do_constant_folding,
                opset_version=opset_version, enable_onnx_checker=False)
        else:
            torch.onnx.export(net, dummy_input, onnx_file_name, do_constant_folding=do_constant_folding,
                opset_version=opset_version)
        log_func("export completed")

    if en_comp_latency and write_onnx:
        comp_latency(inputNetFile=onnx_file_name, tidl_params=tidl_params)

if __name__ == '__main__':
    # Fix: the original called main(), which is not defined anywhere in this
    # module and raised NameError when run as a script. This module is a
    # library of export utilities with no standalone entry point.
    pass
